import gzip
import numpy as np
import matplotlib.pyplot as plt

def load_mnist_images(filename):
    """Load a gzip-compressed IDX3 image file (the MNIST/EMNIST image format).

    The 16-byte IDX3 header holds four big-endian uint32 values: the magic
    number (2051 for unsigned-byte 3-D data), the image count, the number of
    rows and the number of columns. The image dimensions are taken from the
    header instead of being assumed to be 28x28.

    Args:
        filename: Path to the ``.gz`` IDX3 image file.

    Returns:
        A read-only ``numpy.uint8`` array of shape ``(count, rows, cols)``
        backed by the decompressed file contents.

    Raises:
        ValueError: If the file is shorter than the header, the magic number
            is not 2051, or the pixel data is shorter than the header declares.
    """
    header_bytes = 16
    idx3_ubyte_magic = 2051

    with gzip.open(filename, 'rb') as f:
        raw = f.read()

    if len(raw) < header_bytes:
        raise ValueError(f"{filename}: file too short to contain an IDX3 header")

    # Convert to Python ints so count * rows * cols cannot overflow uint32.
    magic, count, rows, cols = (
        int(v) for v in np.frombuffer(raw, dtype='>u4', count=4)
    )
    if magic != idx3_ubyte_magic:
        raise ValueError(
            f"{filename}: bad IDX3 magic number {magic} "
            f"(expected {idx3_ubyte_magic})"
        )

    n_pixels = count * rows * cols
    available = len(raw) - header_bytes
    if available < n_pixels:
        raise ValueError(
            f"{filename}: header declares {n_pixels} pixel bytes "
            f"but only {available} are present"
        )

    # Read exactly the declared number of pixels; trailing bytes are ignored.
    data = np.frombuffer(raw, np.uint8, count=n_pixels, offset=header_bytes)
    return data.reshape(count, rows, cols)

# Path to the gzip-compressed IDX3 image file to preview.
IMAGES_PATH = 'emnist-gzip/emnist-mnist-test-images-idx3-ubyte.gz'
# Preview grid size; the first GRID_ROWS * GRID_COLS images are shown.
GRID_ROWS, GRID_COLS = 5, 5
# EMNIST binaries store each image transposed relative to the original MNIST
# files, so digits render sideways unless rows and columns are swapped.
# Set to False when previewing original MNIST files.
TRANSPOSE_EMNIST = True

# Load the images.
images = load_mnist_images(IMAGES_PATH)
if TRANSPOSE_EMNIST:
    # Swap the row/column axes of every image (axis 0 is the image index).
    images = images.transpose(0, 2, 1)

# Show the first GRID_ROWS * GRID_COLS images.
# squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so `.flat` works.
fig, axes = plt.subplots(GRID_ROWS, GRID_COLS, figsize=(12, 12), squeeze=False)
for i, ax in enumerate(axes.flat):
    # Hide ticks and frame on every cell, including cells left empty when the
    # file holds fewer images than the grid has cells.
    ax.axis('off')
    if i < len(images):
        ax.imshow(images[i], cmap='gray')
plt.tight_layout()
plt.show()

# To save a single image to disk, e.g. the first one, uncomment:
# plt.imsave('image_0.png', images[0], cmap='gray')